"""
Search Cache Utility
Provides intelligent caching for search results to avoid repeated queries.
Includes TTL, LRU eviction, and query normalization.
"""

import hashlib
import threading
import time
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional

from loguru import logger
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from ..config.paths import get_cache_directory
from ..database.models import Base, SearchCache as SearchCacheModel


class SearchCache:
    """
    Persistent cache for search results with TTL and LRU eviction.
    Stores results in SQLite for persistence across sessions.
    """

    def __init__(
        self,
        cache_dir: str = None,
        max_memory_items: int = 1000,
        default_ttl: int = 3600,
    ):
        """
        Initialize search cache.

        Args:
            cache_dir: Directory for cache database. Defaults to data/__CACHE_DIR__
            max_memory_items: Maximum items in memory cache
            default_ttl: Default time-to-live in seconds (1 hour default)
        """
        self.max_memory_items = max_memory_items
        self.default_ttl = default_ttl

        # Setup cache directory
        if cache_dir is None:
            cache_dir = get_cache_directory() / "search_cache"
        else:
            cache_dir = Path(cache_dir)

        cache_dir.mkdir(parents=True, exist_ok=True)
        self.db_path = cache_dir / "search_cache.db"

        # Initialize database
        self._init_db()

        # In-memory cache for frequently accessed items
        self._memory_cache = {}
        self._access_times = {}

        # Stampede protection: events and locks for each query being fetched
        self._fetch_events = {}  # query_hash -> threading.Event (signals completion)
        self._fetch_locks = {}  # query_hash -> threading.Lock (prevents concurrent fetch)
        self._fetch_locks_lock = (
            threading.Lock()
        )  # Protects the fetch dictionaries
        self._fetch_results = {}  # query_hash -> results (temporary storage during fetch)

    def _init_db(self):
        """Initialize SQLite database for persistent cache using SQLAlchemy."""
        try:
            # Create engine and session
            self.engine = create_engine(f"sqlite:///{self.db_path}")
            Base.metadata.create_all(
                self.engine, tables=[SearchCacheModel.__table__]
            )
            self.Session = sessionmaker(bind=self.engine)
        except Exception:
            logger.exception("Failed to initialize search cache database")

    def _normalize_query(self, query: str) -> str:
        """Normalize query for consistent caching."""
        # Convert to lowercase and remove extra whitespace
        normalized = " ".join(query.lower().strip().split())

        # Remove common punctuation that doesn't affect search
        normalized = normalized.replace('"', "").replace("'", "")

        return normalized

    def _get_query_hash(
        self, query: str, search_engine: str = "default"
    ) -> str:
        """Generate hash for query + search engine combination."""
        normalized_query = self._normalize_query(query)
        cache_key = f"{search_engine}:{normalized_query}"
        return hashlib.md5(  # DevSkim: ignore DS126858
            cache_key.encode(), usedforsecurity=False
        ).hexdigest()

    def _cleanup_expired(self):
        """Remove expired entries from database."""
        try:
            current_time = int(time.time())
            with self.Session() as session:
                deleted = (
                    session.query(SearchCacheModel)
                    .filter(SearchCacheModel.expires_at < current_time)
                    .delete()
                )
                session.commit()
                if deleted > 0:
                    logger.debug(f"Cleaned up {deleted} expired cache entries")
        except Exception:
            logger.exception("Failed to cleanup expired cache entries")

    def _evict_lru_memory(self):
        """Evict least recently used items from memory cache."""
        if len(self._memory_cache) <= self.max_memory_items:
            return

        # Sort by access time and remove oldest
        sorted_items = sorted(self._access_times.items(), key=lambda x: x[1])
        items_to_remove = (
            len(self._memory_cache) - self.max_memory_items + 100
        )  # Remove extra for efficiency

        for query_hash, _ in sorted_items[:items_to_remove]:
            self._memory_cache.pop(query_hash, None)
            self._access_times.pop(query_hash, None)

    def get(
        self, query: str, search_engine: str = "default"
    ) -> Optional[List[Dict[str, Any]]]:
        """
        Get cached search results for a query.

        Args:
            query: Search query
            search_engine: Search engine identifier for cache partitioning

        Returns:
            Cached results or None if not found/expired
        """
        query_hash = self._get_query_hash(query, search_engine)
        current_time = int(time.time())

        # Check memory cache first
        if query_hash in self._memory_cache:
            entry = self._memory_cache[query_hash]
            if entry["expires_at"] > current_time:
                self._access_times[query_hash] = current_time
                logger.debug(f"Cache hit (memory) for query: {query[:50]}...")
                return entry["results"]
            else:
                # Expired, remove from memory
                self._memory_cache.pop(query_hash, None)
                self._access_times.pop(query_hash, None)

        # Check database cache
        try:
            with self.Session() as session:
                cache_entry = (
                    session.query(SearchCacheModel)
                    .filter(
                        SearchCacheModel.query_hash == query_hash,
                        SearchCacheModel.expires_at > current_time,
                    )
                    .first()
                )

                if cache_entry:
                    results = cache_entry.results

                    # Update access statistics
                    cache_entry.access_count += 1
                    cache_entry.last_accessed = current_time
                    session.commit()

                    # Add to memory cache
                    self._memory_cache[query_hash] = {
                        "results": results,
                        "expires_at": cache_entry.expires_at,
                    }
                    self._access_times[query_hash] = current_time
                    self._evict_lru_memory()

                    logger.debug(
                        f"Cache hit (database) for query: {query[:50]}..."
                    )
                    return results

        except Exception:
            logger.exception("Failed to retrieve from search cache")

        logger.debug(f"Cache miss for query: {query[:50]}...")
        return None

    def put(
        self,
        query: str,
        results: List[Dict[str, Any]],
        search_engine: str = "default",
        ttl: Optional[int] = None,
    ) -> bool:
        """
        Store search results in cache.

        Args:
            query: Search query
            results: Search results to cache
            search_engine: Search engine identifier
            ttl: Time-to-live in seconds (uses default if None)

        Returns:
            True if successfully cached
        """
        if not results:  # Don't cache empty results
            return False

        query_hash = self._get_query_hash(query, search_engine)
        current_time = int(time.time())
        expires_at = current_time + (ttl or self.default_ttl)

        try:
            # Store in database
            with self.Session() as session:
                # Check if entry exists
                existing = (
                    session.query(SearchCacheModel)
                    .filter_by(query_hash=query_hash)
                    .first()
                )

                if existing:
                    # Update existing entry
                    existing.query_text = self._normalize_query(query)
                    existing.results = results
                    existing.created_at = current_time
                    existing.expires_at = expires_at
                    existing.access_count = 1
                    existing.last_accessed = current_time
                else:
                    # Create new entry
                    cache_entry = SearchCacheModel(
                        query_hash=query_hash,
                        query_text=self._normalize_query(query),
                        results=results,
                        created_at=current_time,
                        expires_at=expires_at,
                        access_count=1,
                        last_accessed=current_time,
                    )
                    session.add(cache_entry)

                session.commit()

            # Store in memory cache
            self._memory_cache[query_hash] = {
                "results": results,
                "expires_at": expires_at,
            }
            self._access_times[query_hash] = current_time
            self._evict_lru_memory()

            logger.debug(f"Cached results for query: {query[:50]}...")
            return True

        except Exception:
            logger.exception("Failed to store in search cache")
            return False

    def get_or_fetch(
        self,
        query: str,
        fetch_func,
        search_engine: str = "default",
        ttl: Optional[int] = None,
    ) -> Optional[List[Dict[str, Any]]]:
        """
        Get cached results or fetch with stampede protection.

        This is the recommended way to use the cache. It ensures only one thread
        fetches data for a given query, preventing cache stampedes.

        Args:
            query: Search query
            fetch_func: Function to call if cache miss (should return results list)
            search_engine: Search engine identifier
            ttl: Time-to-live for cached results

        Returns:
            Search results (from cache or freshly fetched)
        """
        query_hash = self._get_query_hash(query, search_engine)

        # Try to get from cache first
        results = self.get(query, search_engine)
        if results is not None:
            return results

        # Acquire lock for this query to prevent stampede
        with self._fetch_locks_lock:
            # Double-check after acquiring lock
            results = self.get(query, search_engine)
            if results is not None:
                return results

            # Check if another thread started fetching while we waited
            if query_hash in self._fetch_events:
                existing_event = self._fetch_events[query_hash]
                # Check if this is a stale event (already set means fetch completed)
                if existing_event.is_set():
                    # Previous fetch completed, clean up and start fresh
                    del self._fetch_events[query_hash]
                    del self._fetch_locks[query_hash]
                    if query_hash in self._fetch_results:
                        del self._fetch_results[query_hash]
                    # Create new event/lock for this fetch
                    event = threading.Event()
                    self._fetch_events[query_hash] = event
                    self._fetch_locks[query_hash] = threading.Lock()
                    event = None  # Signal we should fetch
                else:
                    # Another thread is actively fetching
                    event = existing_event
            else:
                # We are the first thread to fetch this query
                event = threading.Event()
                self._fetch_events[query_hash] = event
                self._fetch_locks[query_hash] = threading.Lock()
                event = None  # Signal we should fetch

        # If another thread is fetching, wait for it
        if event is not None:
            event.wait(timeout=30)
            if query_hash in self._fetch_results:
                result = self._fetch_results.get(query_hash)
                if result is not None:
                    return result
            # Re-check cache, and if still miss, return None (fetch failed)
            return self.get(query, search_engine)

        # We are the thread that should fetch
        fetch_lock = self._fetch_locks[query_hash]
        fetch_event = self._fetch_events[query_hash]

        with fetch_lock:
            # Triple-check (another thread might have fetched while we waited for lock)
            results = self.get(query, search_engine)
            if results is not None:
                fetch_event.set()  # Signal completion
                return results

            logger.debug(
                f"Fetching results for query: {query[:50]}... (stampede protected)"
            )

            try:
                # Fetch the results
                results = fetch_func()

                if results:
                    # Store in cache
                    self.put(query, results, search_engine, ttl)

                    # Store temporarily for other waiting threads
                    self._fetch_results[query_hash] = results

                return results

            except Exception:
                logger.exception(
                    f"Failed to fetch results for query: {query[:50]}"
                )
                # Store None to indicate fetch failed
                self._fetch_results[query_hash] = None
                return None

            finally:
                # Signal completion to waiting threads
                fetch_event.set()

                # Clean up after a delay
                def cleanup():
                    time.sleep(2)  # Give waiting threads time to get results
                    with self._fetch_locks_lock:
                        self._fetch_locks.pop(query_hash, None)
                        self._fetch_events.pop(query_hash, None)
                        self._fetch_results.pop(query_hash, None)

                # Run cleanup in background
                threading.Thread(target=cleanup, daemon=True).start()

    def invalidate(self, query: str, search_engine: str = "default") -> bool:
        """Invalidate cached results for a specific query."""
        query_hash = self._get_query_hash(query, search_engine)

        try:
            # Remove from memory
            self._memory_cache.pop(query_hash, None)
            self._access_times.pop(query_hash, None)

            # Remove from database
            with self.Session() as session:
                deleted = (
                    session.query(SearchCacheModel)
                    .filter_by(query_hash=query_hash)
                    .delete()
                )
                session.commit()

            logger.debug(f"Invalidated cache for query: {query[:50]}...")
            return deleted > 0

        except Exception:
            logger.exception("Failed to invalidate cache")
            return False

    def clear_all(self) -> bool:
        """Clear all cached results."""
        try:
            self._memory_cache.clear()
            self._access_times.clear()

            with self.Session() as session:
                session.query(SearchCacheModel).delete()
                session.commit()

            logger.info("Cleared all search cache")
            return True

        except Exception:
            logger.exception("Failed to clear search cache")
            return False

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics."""
        try:
            current_time = int(time.time())
            with self.Session() as session:
                # Total entries
                total_entries = (
                    session.query(SearchCacheModel)
                    .filter(SearchCacheModel.expires_at > current_time)
                    .count()
                )

                # Total expired entries
                expired_entries = (
                    session.query(SearchCacheModel)
                    .filter(SearchCacheModel.expires_at <= current_time)
                    .count()
                )

                # Average access count
                from sqlalchemy import func

                avg_access_result = (
                    session.query(func.avg(SearchCacheModel.access_count))
                    .filter(SearchCacheModel.expires_at > current_time)
                    .scalar()
                )
                avg_access = avg_access_result or 0

            return {
                "total_valid_entries": total_entries,
                "expired_entries": expired_entries,
                "memory_cache_size": len(self._memory_cache),
                "average_access_count": round(avg_access, 2),
                "cache_hit_potential": (
                    f"{(total_entries / (total_entries + 1)) * 100:.1f}%"
                    if total_entries > 0
                    else "0%"
                ),
            }

        except Exception:
            logger.exception("Failed to get cache stats")
            return {"error": "Cache stats unavailable"}


# Global cache instance
_global_cache = None


def get_search_cache() -> SearchCache:
    """Get global search cache instance."""
    global _global_cache
    if _global_cache is None:
        _global_cache = SearchCache()
    return _global_cache


@lru_cache(maxsize=100)
def normalize_entity_query(entity: str, constraint: str) -> str:
    """
    Normalize entity + constraint combination for consistent caching.
    Uses LRU cache for frequent normalizations.
    """
    # Remove quotes and normalize whitespace
    entity_clean = " ".join(entity.strip().lower().split())
    constraint_clean = " ".join(constraint.strip().lower().split())

    # Create canonical form
    return f"{entity_clean} {constraint_clean}"
